import os
import sys
import gzip
from collections import defaultdict, Counter
import pybedtools
from scipy.special import bdtrc, chdtrc
from scipy.stats import pearsonr, spearmanr

import pysam
from Bio import SeqIO
from pylab import *


# Half-width of the window around each exon-exon boundary in which the
# positions of fragment 3' ends are profiled; profile arrays therefore have
# 2*halfwindow+1 bins, covering offsets -halfwindow..+halfwindow.
halfwindow = 50

def read_gene_expression(filename="genes.FANTOM_CAT.THP-1.counts.txt"):
    """Read a gene count table and return replicate-averaged CPM per time point.

    The first line of *filename* must start with ``gene`` followed by one
    column name per sample, written as ``<timepoint>_<replicate>`` where
    ``<timepoint>`` is one of ``00_hr``, ``01_hr``, ``04_hr``, ``12_hr``,
    ``24_hr``, ``96_hr``. Every other line holds a gene name followed by one
    integer count per sample.

    Each sample column is scaled to counts per million (column sum = 1e6).
    The scaled columns are then averaged over the samples of each time point.

    Returns a dict mapping gene name -> array of 6 averaged CPM values, in
    the time-point order listed above. A time point with no sample column
    yields NaN (0/0).
    """
    timepoints = ("00_hr", "01_hr", "04_hr", "12_hr", "24_hr", "96_hr")
    genes = []
    rows = []
    # Use a context manager so the file is closed even if parsing fails.
    with open(filename) as stream:
        words = next(stream).split()
        assert words[0] == "gene"
        samples = words[1:]
        for line in stream:
            words = line.split()
            genes.append(words[0])
            rows.append([int(word) for word in words[1:]])
    # reshape keeps the table two-dimensional even when it has no gene rows.
    counts = array(rows, int).reshape(len(genes), len(samples))
    cpm = 1.e6 * counts / counts.sum(axis=0)
    averaged_counts = zeros((len(genes), len(timepoints)))
    n_samples = zeros(len(timepoints))
    for index, sample in enumerate(samples):
        # Unpacking enforces the "<timepoint>_<replicate>" naming scheme.
        timepoint, _replicate = sample.rsplit("_", 1)
        j = timepoints.index(timepoint)
        n_samples[j] += 1
        averaged_counts[:, j] += cpm[:, index]
    averaged_counts /= n_samples
    return {gene: averaged_counts[index] for index, gene in enumerate(genes)}

# Exact five-line header expected at the top of every PSL file
# ("psLayout version 3" layout).
PSL_HEADER = """\
psLayout version 3

match	mis- 	rep. 	N's	Q gap	Q gap	T gap	T gap	strand	Q        	Q   	Q    	Q  	T        	T   	T    	T  	block	blockSizes 	qStarts	 tStarts
     	match	match	   	count	bases	count	bases	      	name     	size	start	end	name     	size	start	end	count
---------------------------------------------------------------------------------------------------------------------------------------------------------------
"""


def parse_alignments(path):
    """Yield one record per read pair from the gzipped PSL file at *path*.

    After the header (which must equal PSL_HEADER), data lines come in
    consecutive pairs. The first line of a pair is on the '+' strand and the
    second on the '-' strand. Both lines share the query name, the target
    name and the target size.

    Yields ``(qName, tName, tSize, tStart, tEnd)``. ``tStart`` is taken from
    the '+' line and ``tEnd`` from the '-' line, so the tuple describes the
    target interval covered by the pair.

    Raises AssertionError if the header or a pair's consistency checks fail.
    Raises ValueError if the last data line has no mate.
    """
    print("Reading", path)
    # The context manager closes the file even if the consumer stops early.
    with gzip.open(path, "rt") as handle:
        header = "".join(next(handle, "") for _ in range(5))
        assert header == PSL_HEADER
        for line1 in handle:
            # next() without a default would raise StopIteration here, which
            # PEP 479 turns into an opaque RuntimeError inside a generator.
            line2 = next(handle, None)
            if line2 is None:
                raise ValueError("%s: alignment line without a mate at end of file" % path)
            words1 = line1.split()
            words2 = line2.split()
            assert words1[8] == '+'
            assert words2[8] == '-'
            qName = words1[9]
            assert words2[9] == qName
            tName = words1[13]
            assert words2[13] == tName
            tSize = int(words1[14])
            assert int(words2[14]) == tSize
            tStart = int(words1[15])
            tEnd = int(words2[16])
            yield (qName, tName, tSize, tStart, tEnd)
    
def read_spliceboundaries(targets, directory="/osc-fs_home/mdehoon/Data/CASPARs/Filters"):
    """Return exon-exon boundary positions per transcript, in transcript coordinates.

    For each target, reads ``<directory>/<target>.psl``, a PSL file without a
    header that has 21 columns per line. Column 9 is the query (transcript)
    name, column 8 the strand and column 18 the comma-terminated block sizes.
    On the '-' strand the block sizes are reversed so they run in transcript
    order.

    The boundaries are the cumulative block sizes without the final total.
    A single-block transcript therefore maps to an empty array.

    Returns a dict mapping transcript name -> numpy array of boundary offsets.
    Raises AssertionError on a malformed line or a duplicate transcript name.
    """
    boundaries = {}
    for target in targets:
        path = os.path.join(directory, "%s.psl" % target)
        print("Reading splice boundaries from", path)
        # Context manager: close the file even if an assertion fails mid-way.
        with open(path) as handle:
            for line in handle:
                words = line.split()
                assert len(words) == 21
                qName = words[9]
                strand = words[8]
                blockSizes = [int(blockSize) for blockSize in words[18].split(",")[:-1]]
                if strand == '-':
                    blockSizes.reverse()
                assert qName not in boundaries
                boundaries[qName] = cumsum(blockSizes)[:-1]
    return boundaries

def calculate_enrichment(alignments, boundaries, genes):
    """Test each gene for enrichment of fragment 3' ends at exon-exon boundaries.

    alignments: iterable of ``(qName, tName, tSize, tStart, tEnd)`` tuples
        (as yielded by parse_alignments). Fragments with
        ``tEnd - tStart > 0.9 * tSize`` (almost full-length) are ignored.
    boundaries: dict transcript name -> array of boundary offsets.
        Transcripts absent from it are skipped.
    genes: dict transcript name -> gene name.

    A background fragment-length distribution is built from all truncated
    fragments. Each fragment is weighted by
    1 / (alignments of its read * truncated fragments on its transcript).

    Each fragment on a transcript with boundaries contributes, with weight
    1 / (alignments of its read), to a per-gene 2x2 table:
      [0, 0] fragments whose end is exactly at a boundary,
      [0, 1] fragments ending elsewhere,
      [1, 0] expected number ending at a boundary, from the background
             length distribution and each fragment's start.
    The p-value is the binomial upper tail P(X >= observed), with
    n = rounded total and p = expected / n.

    Returns ``(pvalues, rnaSizes, genes_found, contingency)``:
      pvalues: gene -> p-value. It is 1.0 when nothing ends at a boundary or
        the rounded total is 0.
      rnaSizes: gene -> list of ``[length, weight]`` for fragments ending at
        a boundary.
      genes_found: genes with at least one truncated fragment on a transcript
        listed in *boundaries*.
      contingency: gene -> the 2x2 table described above.
    """
    # Materialise once: the data is traversed several times below, and a
    # generator would otherwise be exhausted after the first pass.
    truncated = [(qName, tName, tSize, tStart, tEnd)
                 for qName, tName, tSize, tStart, tEnd in alignments
                 if tEnd - tStart <= 0.9 * tSize]

    multimap = Counter()
    genecount = Counter()
    tLengthMax = 0
    for qName, tName, tSize, tStart, tEnd in truncated:
        multimap[qName] += 1
        genecount[tName] += 1
        assert tStart < tEnd
        if tEnd - tStart > tLengthMax:
            tLengthMax = tEnd - tStart

    # Background probability of each fragment length.
    tLengthDistribution = zeros(tLengthMax + 1, float)
    for qName, tName, tSize, tStart, tEnd in truncated:
        weight = 1.0 / (multimap[qName] * genecount[tName])
        tLengthDistribution[tEnd - tStart] += weight
    total_weight = tLengthDistribution.sum()
    if total_weight > 0:
        tLengthDistribution /= total_weight

    rnaSizes = defaultdict(list)
    contingency = defaultdict(lambda: zeros((2, 2), float))
    genes_found = set()
    for qName, tName, tSize, tStart, tEnd in truncated:
        positions = boundaries.get(tName)
        if positions is None:
            continue
        gene = genes[tName]
        genes_found.add(gene)
        if len(positions) == 0:
            continue
        weight = 1.0 / multimap[qName]
        table = contingency[gene]
        if tEnd in positions:
            rnaSizes[gene].append([tEnd - tStart, weight])
            table[0, 0] += weight
        else:
            table[0, 1] += weight
        # Probability that a fragment starting at tStart would end at each
        # boundary, if its length were drawn from the background distribution.
        for position in positions:
            length = position - tStart
            if 0 <= length <= tLengthMax:
                table[1, 0] += tLengthDistribution[length] * weight

    pvalues = {}
    for gene, table in contingency.items():
        foreground_positive, foreground_negative = table[0, :]
        total = int(round(foreground_positive + foreground_negative))
        if foreground_positive == 0 or total == 0:
            # No evidence; also avoids background/0 -> inf -> NaN p-values.
            pvalue = 1.0
        else:
            p = table[1, 0] / total
            # Rounding total down can push p above 1, where bdtrc returns NaN.
            if p > 1.0:
                p = 1.0
            pvalue = bdtrc(int(round(foreground_positive)) - 1, total, p)
        pvalues[gene] = pvalue
    return pvalues, rnaSizes, genes_found, contingency

def fill_contingency_table(alignments, boundaries, genes, contingency_tables, index):
    """Add weighted fragment counts to column *index* of each gene's table.

    alignments: iterable of ``(qName, tName, tSize, tStart, tEnd)`` tuples.
        Almost full-length fragments (``tEnd - tStart > 0.9 * tSize``) are
        ignored.
    boundaries: dict transcript name -> array of boundary offsets.
        Transcripts that are absent, or have no boundaries, are skipped.
    genes: dict transcript name -> gene name.
    contingency_tables: mapping gene -> 2-row array, updated in place.
        Row 0 counts fragments whose end is exactly at a boundary; row 1
        counts the others. Presumably a defaultdict that creates missing
        tables (it is accessed with ``[]``).
    index: column to update (e.g. a time point).

    Each fragment contributes 1 / (number of truncated alignments of its read).
    """
    # Materialise once so any iterable (including a generator) can be used.
    truncated = [(qName, tName, tSize, tStart, tEnd)
                 for qName, tName, tSize, tStart, tEnd in alignments
                 if tEnd - tStart <= 0.9 * tSize]
    multimap = Counter(qName for qName, _, _, _, _ in truncated)
    for qName, tName, tSize, tStart, tEnd in truncated:
        positions = boundaries.get(tName)
        if positions is None or len(positions) == 0:
            continue
        gene = genes[tName]
        row = 0 if tEnd in positions else 1
        contingency_tables[gene][row, index] += 1.0 / multimap[qName]


def calculate_profile(path, boundaries, genes, profile):
    """Accumulate, per gene, the positions of fragment ends relative to boundaries.

    Alignments are read from *path* with parse_alignments. Almost full-length
    fragments (``tEnd - tStart > 0.9 * tSize``) and transcripts absent from
    *boundaries* are skipped. For every other fragment, the offsets
    ``tEnd - boundary`` with ``|offset| <= halfwindow`` are collected.

    Each read gets weight 1 / (number of transcripts it was recorded on).
    That weight is split evenly over the read's offsets on each transcript
    and added to ``profile[gene][offset + halfwindow]``.

    profile: mapping gene -> array, modified in place and returned. It is
        expected to create missing entries of length 2*halfwindow+1 (the
        caller passes such a defaultdict).
    """
    # distances_by_read[qName][tName] -> offsets of that fragment's end
    # from the transcript's boundaries, restricted to the window.
    distances_by_read = defaultdict(dict)
    for qName, tName, tSize, tStart, tEnd in parse_alignments(path):
        if tEnd - tStart > 0.9 * tSize:  # almost full-length
            continue
        positions = boundaries.get(tName)
        if positions is None:
            continue
        distances = tEnd - positions
        distances_by_read[qName][tName] = distances[abs(distances) <= halfwindow]
    for qName, hits in distances_by_read.items():
        weight = 1.0 / len(hits)
        for tName, distances in hits.items():
            gene = genes[tName]
            n = len(distances)
            for distance in distances:
                profile[gene][distance + halfwindow] += weight / n
    return profile

def read_genes(targets, directory="/osc-fs_home/mdehoon/Data/CASPARs/Filters"):
    """Map transcripts to genes and split the genes into coding and non-coding sets.

    For each target ("mRNA" or "lncRNA"), reads ``<directory>/<target>.gff``
    with pybedtools. Each feature's ``transcript`` and ``gene`` attributes are
    recorded.

    Returns ``(genes, coding_genes, noncoding_genes)``:
      genes: dict transcript -> gene.
      coding_genes: genes seen in the mRNA file.
      noncoding_genes: genes seen in the lncRNA file that are not also coding.

    Raises ValueError (a subclass of Exception) for any other target name.
    """
    genes = {}
    coding_genes = set()
    noncoding_genes = set()
    categories = {"mRNA": coding_genes, "lncRNA": noncoding_genes}
    for target in targets:
        try:
            category_genes = categories[target]
        except KeyError:
            raise ValueError("Unknown target %s" % target) from None
        path = os.path.join(directory, "%s.gff" % target)
        print("Reading", path)
        for line in pybedtools.BedTool(path):
            transcript = line.attrs['transcript']
            gene = line.attrs['gene']
            genes[transcript] = gene
            category_genes.add(gene)
    # A gene annotated with both kinds of transcript counts as coding.
    noncoding_genes = noncoding_genes.difference(coding_genes)
    return genes, coding_genes, noncoding_genes

# Root of the CASPAR data; the PSL alignment files live under MiSeq/PSL.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/"
# Transcript categories analysed; each has its own alignment and annotation files.
targets = ("mRNA", "lncRNA")
# Genes whose per-gene p-value is below alpha count as enriched at splice sites.
alpha = 0.05
pvalues = {}

f = figure(figsize=(5.0, 4.0))
# x-axis (offset from an exon-exon boundary) for the 3'-end profile panels.
x = arange(-halfwindow, halfwindow + 1)

# NOTE(review): gene_expression is not referenced elsewhere in the visible script.
gene_expression = read_gene_expression()
genes, coding_genes, noncoding_genes = read_genes(targets)
boundaries = read_spliceboundaries(targets)
# Collect the read-pair alignments of every time-course library, target by
# target, in sorted file-name order.
alignments = []
for target in targets:
    subdirectory = os.path.join(directory, "MiSeq", "PSL")
    for filename in sorted(os.listdir(subdirectory)):
        if not filename.endswith(".psl.gz"):
            continue
        library, file_target = filename.split(".")[:2]
        # include time course samples only (library names starting with "t")
        if file_target == target and library.startswith("t"):
            path = os.path.join(subdirectory, filename)
            alignments.extend(parse_alignments(path))


pvalues, rnaSizes, genes_found, contingency = calculate_enrichment(alignments, boundaries, genes)


n = len(genes_found.intersection(coding_genes))
print("Number of coding genes with truncated reads: %d" % n)
n = len(genes_found.intersection(noncoding_genes))
print("Number of noncoding genes with truncated reads: %d" % n)

for category_name, category in (("coding", coding_genes), ("noncoding", noncoding_genes)):
    tested = [gene for gene in pvalues if gene in category]
    number = len(tested)
    enriched = len([gene for gene in tested if pvalues[gene] < alpha])
    # Avoid ZeroDivisionError when no gene of this category was tested.
    percentage = 100.0 * enriched / number if number else 0.0
    print("Number of %s genes with enrichment at splice sites: %d out of %d (%.1f%%)" % (category_name, enriched, number, percentage))

# Fisher's method: -2 * sum(log p) is compared against a chi-square
# distribution with 2n degrees of freedom (chdtrc = upper tail).
chisquare = -2 * sum([log(pvalue) for pvalue in pvalues.values()])
n = len(pvalues)
pvalue = chdtrc(2*n, chisquare)
# "%.2f" would print 0.00 for every combined p-value below 0.005.
print("Fisher-combined p-value = %.3g" % pvalue)

# Largest size (in nt) among RNAs that terminate exactly at a splice site.
rnaSizeMax = 0
for sizes_and_weights in rnaSizes.values():
    for rnaSize, weight in sizes_and_weights:
        if rnaSize > rnaSizeMax:
            rnaSizeMax = rnaSize

# Weighted size histograms, kept separately for coding and non-coding genes.
rnaSizeDistribution_coding = zeros(rnaSizeMax+1)
rnaSizeDistribution_noncoding = zeros(rnaSizeMax+1)
for gene, sizes_and_weights in rnaSizes.items():
    if gene in coding_genes:
        rnaSizeDistribution = rnaSizeDistribution_coding
    elif gene in noncoding_genes:
        rnaSizeDistribution = rnaSizeDistribution_noncoding
    else:
        raise Exception(gene)
    for rnaSize, weight in sizes_and_weights:
        rnaSizeDistribution[rnaSize] += weight

# Two frameless axes, each spanning one row of the 2x2 panel grid, carry
# only the row labels on the left-hand side.
for subplot_position, row_label in ((211, "Coding genes"), (212, "Non-coding genes")):
    ax = f.add_subplot(subplot_position)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('none')
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
    ax.set_ylabel(row_label, labelpad=25, fontsize=8)

# Right-hand column: size histograms (0-400 nt) of RNAs ending at splice sites.
for subplot_position, size_distribution in ((222, rnaSizeDistribution_coding), (224, rnaSizeDistribution_noncoding)):
    f.add_subplot(subplot_position)
    # The histogram may be shorter than 401 bins (rnaSizeMax < 400); derive
    # the x positions from the slice so bar() gets equal-length inputs.
    shown = size_distribution[:401]
    bar(arange(len(shown)), shown, color='black', snap=False)
    xlim(0, 400)
    xticks(fontsize=8)
    yticks(fontsize=8)
xlabel("Size of RNAs terminating\nat splice sites\n[nucleotides]", fontsize=8)

rnaSizeDistribution = rnaSizeDistribution_coding + rnaSizeDistribution_noncoding
rnaSizeDistribution = rnaSizeDistribution / sum(rnaSizeDistribution)
cumulative_distribution = cumsum(rnaSizeDistribution)
# Median = smallest size whose cumulative fraction reaches 0.5.
# argmin(|F - 0.5|) could select a size just below it when F jumps over 0.5.
median_size = searchsorted(cumulative_distribution, 0.5)
print("Median size of RNAs terminating at splice sites: %d nt" % median_size)

# Per-gene tables with one column per time point.
# Row 0: fragments ending at a splice site; row 1: the rest.
timepoints = ("t00", "t01", "t04", "t12", "t24", "t96")
contingency_tables = defaultdict(lambda: zeros((2, len(timepoints))))
for target in targets:
    subdirectory = os.path.join(directory, "MiSeq", "PSL")
    libraries = []
    for filename in sorted(os.listdir(subdirectory)):
        if not filename.endswith(".psl.gz"):
            continue
        library, file_target = filename.split(".")[:2]
        if file_target != target or not library.startswith("t"):
            continue  # include time course samples only
        libraries.append(library)
        timepoint, replicate = library.split("_")
        assert timepoint in timepoints
        assert replicate in ("r1", "r2", "r3")
    for index, timepoint in enumerate(timepoints):
        # Pool all replicates of this time point before filling its column.
        alignments = []
        for library in libraries:
            if library.startswith(timepoint):
                filename = "%s.%s.psl.gz" % (library, target)
                alignments.extend(parse_alignments(os.path.join(subdirectory, filename)))
        fill_contingency_table(alignments, boundaries, genes, contingency_tables, index)

# For genes significantly enriched at splice sites, the fraction of their
# truncated fragments (pooled over all time points) that end exactly at a
# splice site. Row 0 of each contingency table holds those at-site counts.
fractions = {}
for gene in [gene for gene in pvalues if pvalues[gene] < alpha]:
    # get() avoids creating empty tables; genes without counts are skipped
    # instead of producing a 0/0 = NaN fraction.
    table = contingency_tables.get(gene)
    if table is None:
        continue
    denominator = table.sum()
    if denominator == 0:
        continue
    fractions[gene] = table[0, :].sum() / denominator

average_fraction_coding = mean([fractions[gene] for gene in fractions if gene in coding_genes])
average_fraction_noncoding = mean([fractions[gene] for gene in fractions if gene in noncoding_genes])

print("Average percentage of short capped RNAs terminating at splice sites of coding genes, if significant: %.1f%%" % (100.0 * average_fraction_coding))
print("Average percentage of short capped RNAs terminating at splice sites of non-coding genes, if significant: %.1f%%" % (100.0 * average_fraction_noncoding))

# Per-gene histograms of fragment-end offsets around exon-exon boundaries,
# accumulated over all time-course libraries of both targets.
profiles = defaultdict(lambda: zeros(2*halfwindow+1))
for target in targets:
    subdirectory = os.path.join(directory, "MiSeq", "PSL")
    for filename in sorted(os.listdir(subdirectory)):
        if not filename.endswith(".psl.gz"):
            continue
        library, file_target = filename.split(".")[:2]
        # include time course samples only
        if file_target == target and library.startswith("t"):
            path = os.path.join(subdirectory, filename)
            calculate_profile(path, boundaries, genes, profiles)


profiles = dict(profiles)
profile_coding = zeros(2*halfwindow+1)
profile_noncoding = zeros(2*halfwindow+1)
total_coding = 0
total_noncoding = 0

x = arange(-halfwindow, halfwindow+1)

# Give every gene's profile unit mass (in place), pool the profiles per
# category, and remember each category's summed unnormalised mass.
for gene, gene_profile in profiles.items():
    n = sum(gene_profile)
    gene_profile /= n
    if gene in coding_genes:
        profile_coding += gene_profile
        total_coding += n
    elif gene in noncoding_genes:
        profile_noncoding += gene_profile
        total_noncoding += n

# Rescale each pooled profile so that its area equals that summed mass.
for pooled_profile, pooled_total in ((profile_coding, total_coding), (profile_noncoding, total_noncoding)):
    pooled_profile /= sum(pooled_profile)
    pooled_profile *= pooled_total

# Left-hand column: pooled 3'-end profiles (coding on top, non-coding below).
for subplot_position, pooled_profile in ((221, profile_coding), (223, profile_noncoding)):
    f.add_subplot(subplot_position)
    plot(x, pooled_profile, color='black')
    if subplot_position == 223:
        xlabel("Position of the 3' end with respect\nto exon-exon boundaries\n[base pairs]", fontsize=8)
    xticks(fontsize=8)
    yticks(fontsize=8)
    ylabel("Number of sequences", fontsize=8)

subplots_adjust(bottom=0.2, top=0.9, left=0.2, right=0.95, hspace=0.3, wspace=0.4)

for filename in ("figure_spliceboundary_timecourse.svg", "figure_spliceboundary_timecourse.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)
